
import torch

def energy_score(dataset_in,
                 dataset_out,
                 net,
                 device,
                 temp=1.0):
    """Score in-distribution and OoD test samples with energy plus cosine.

    For every sample yielded by ``dataset_in.test_loader`` and then
    ``dataset_out.test_loader`` two quantities are computed:

    * the energy term ``temp * logsumexp(logits / temp)`` over dim 1 of the
      first output of ``net(data, latent=True)``;
    * the largest cosine similarity between the second output of
      ``net(data, latent=True)`` and the rows of ``net.linear.weight``.

    The energy terms are min-max scaled to [0, 1] across all samples of both
    loaders, the cosine term is added, and the sum is negated.

    Args:
        dataset_in: Object exposing ``test_loader``; its samples get label 0.
        dataset_out: Object exposing ``test_loader``; its samples get label 1.
        net: Model with a ``linear`` layer, callable as
            ``net(data, latent=True)`` and returning ``(logits, features)``.
            ``features`` must be 2-D with the same trailing size as
            ``net.linear.weight`` (presumably the input of ``net.linear`` --
            TODO confirm against the model definition). The model is assumed
            to already live on ``device`` and to be in the desired
            train/eval mode; neither is changed here.
        device: Device the input batches and classifier weights are moved to.
        temp: Temperature of the energy term. Defaults to 1.0.

    Returns:
        Tuple ``(labels, pred)`` of 1-D NumPy arrays with one entry per
        sample actually yielded by the loaders, in-distribution samples
        first. ``labels`` holds 0.0 / 1.0 and ``pred`` holds the negated
        combined score.
    """
    loaders = [dataset_in.test_loader, dataset_out.test_loader]
    # The original buffers were created with torch.zeros(), i.e. in the
    # default dtype; keep the returned arrays in that dtype.
    score_dtype = torch.get_default_dtype()

    # Only the classifier weights are needed. The bias is deliberately not
    # read so that layers created with bias=False work as well.
    class_directions = torch.nn.functional.normalize(
        net.linear.weight.detach().to(device), dim=1)

    energy_chunks = []
    cosine_chunks = []
    samples_per_loader = []

    with torch.no_grad():
        for loader in loaders:
            seen = 0
            for data, _ in loader:
                data = data.to(device)

                logits, features = net(data, latent=True)
                energy = temp * torch.logsumexp(logits / temp, dim=1)

                unit_features = torch.nn.functional.normalize(features, dim=1)
                max_cosine = torch.mm(
                    unit_features, class_directions.T).max(dim=1)[0]

                energy_chunks.append(energy.to(score_dtype))
                cosine_chunks.append(max_cosine.to(score_dtype))
                seen += data.shape[0]
            # Count what the loader really produced: with drop_last or a
            # subset sampler this can be fewer than len(loader.dataset).
            samples_per_loader.append(seen)

    # Label 0 for the in-distribution loader, 1 for the OoD loader.
    labels = torch.cat([
        torch.full((count,), float(loader_index), dtype=score_dtype)
        for loader_index, count in enumerate(samples_per_loader)
    ]).numpy()

    if not energy_chunks:
        # Nothing to score; min()/max() below would fail on empty input.
        return labels, torch.zeros(0, dtype=score_dtype).numpy()

    energy_all = torch.cat(energy_chunks)
    cosine_all = torch.cat(cosine_chunks)

    # Make the energy term lie between 0 and 1.
    lowest = energy_all.min()
    span = energy_all.max() - lowest
    if span == 0:
        # All energies identical: scaling would divide by zero and turn
        # every score into NaN, so the energy term contributes nothing.
        energy_scaled = torch.zeros_like(energy_all)
    else:
        energy_scaled = (energy_all - lowest) / span

    pred = -(energy_scaled + cosine_all).cpu().numpy()
    return labels, pred
